import math
import numpy as np


def warmup_lr(args, epoch, dl, it, optimizer, log):
    """Linearly ramp the learning rate during the warm-up phase.

    The learning rate is set to ``args.lr * it_g / args.warmup_iters`` where
    ``it_g`` is the 1-based global iteration. Outside the warm-up window the
    optimizer(s) and ``log`` are left untouched.

    The warm-up window depends on ``args.crl``:

    * ``args.crl`` truthy: active while ``args.warm`` is truthy and
      ``epoch <= args.warm_epochs``.
    * otherwise: active while ``it_g <= args.warmup_iters``.

    NOTE(review): the ``crl`` path used to also compute a
    ``warmup_from``/``warmup_to`` interpolation that was overwritten before
    use. It has been removed; the effective learning rate is unchanged, and
    a zero ``warm_epochs`` or missing ``warmup_from``/``warmup_to`` no longer
    raises. Confirm the linear ramp is the schedule that is actually wanted.

    Args:
        args: Config namespace; reads ``crl``, ``lr``, ``warmup_iters``,
            ``mcog_hier`` and, when ``crl`` is set, ``warm`` and
            ``warm_epochs``.
        epoch: Current epoch index.
        dl: Sized object; ``len(dl)`` is used as iterations per epoch.
        it: Iteration index within the current epoch.
        optimizer: Object with ``param_groups``; when ``args.mcog_hier`` is
            truthy, an indexable whose items 0 and 1 each have
            ``param_groups``.
        log: Object with an ``update(n, lr=...)`` method; called once with
            the new learning rate when warm-up is applied.
    """
    it_g = 1 + it + epoch * len(dl)  # global training iteration (1-based)

    if args.crl:
        warming_up = args.warm and epoch <= args.warm_epochs
    else:
        warming_up = it_g <= args.warmup_iters
    if not warming_up:
        return

    lr = args.lr * it_g / float(args.warmup_iters)
    for param_group in _warmup_param_groups(args, optimizer):
        param_group['lr'] = lr
    log.update(1, lr=lr)


def _warmup_param_groups(args, optimizer):
    """Yield every param group of the optimizer (or of both, if mcog_hier)."""
    if args.mcog_hier:
        for idx in (0, 1):
            yield from optimizer[idx].param_groups
    else:
        yield from optimizer.param_groups


def adjust_lr(args, epoch, optimizer, log):
    """Apply the per-epoch learning-rate schedule to the optimizer(s).

    Two independent steps run in order:

    1. If ``args.cosine`` is truthy, every param group is set to a cosine
       value between ``args.lr`` and ``args.lr * args.lr_decay_rate ** 3``
       evaluated at ``epoch / args.epochs``. Nothing is logged here.
    2. If ``epoch`` is in ``args.lr_decay_epochs``, each param group's
       current ``lr`` is multiplied by ``args.lr_decay_rate`` and the last
       value written (``0`` if there were no groups) is passed to
       ``log.update(1, lr=...)``.

    Args:
        args: Config namespace; reads ``cosine``, ``lr``, ``lr_decay_rate``,
            ``epochs``, ``lr_decay_epochs`` and ``mcog_hier``.
        epoch: Current epoch index.
        optimizer: Object with ``param_groups``; when ``args.mcog_hier`` is
            truthy, an indexable whose items 0 and 1 each have
            ``param_groups``.
        log: Object with an ``update(n, lr=...)`` method.
    """

    def _param_groups():
        # Walk the groups of the single optimizer, or of optimizer[0] then
        # optimizer[1] in the hierarchical (mcog_hier) setup.
        if args.mcog_hier:
            for idx in (0, 1):
                yield from optimizer[idx].param_groups
        else:
            yield from optimizer.param_groups

    if args.cosine:  # supcon
        floor = args.lr * (args.lr_decay_rate ** 3)
        cos_term = 1 + math.cos(math.pi * epoch / args.epochs)
        annealed = floor + (args.lr - floor) * cos_term / 2
        for group in _param_groups():
            group['lr'] = annealed

    if epoch not in args.lr_decay_epochs:
        return

    # Step decay is relative to whatever lr each group currently holds.
    decayed = 0
    for group in _param_groups():
        decayed = group['lr'] * args.lr_decay_rate
        group['lr'] = decayed
    log.update(1, lr=decayed)


# dino, trades
def other_adjust_lr(args, optimizer, epoch):
    """Set and return the learning rate for ``epoch`` under ``args.lr_schedule``.

    Only the value for the given epoch is computed; it is written to every
    entry of ``optimizer.param_groups`` and returned.

    Supported ``args.lr_schedule`` values:

    * ``'trades'``: ``lr * 0.1`` from 75% of ``args.epochs`` onwards.
    * ``'trades_fixed'``: ``lr * 0.1`` from 75%, ``lr * 0.01`` from 90%,
      ``lr * 0.001`` once ``epoch >= args.epochs``.
    * ``'cosine'``: ``lr * 0.5 * (1 + cos(pi * (epoch - 1) / args.epochs))``.
    * ``'wrn'``: multiply by 0.2 at 30%, again at 60%, again at 80%.

    Raises:
        ValueError: If ``args.lr_schedule`` is none of the above. The
            optimizer is not modified in that case.
    """
    base = args.lr
    kind = args.lr_schedule
    new_lr = base

    # Within each schedule the latest milestone reached wins, so the
    # milestones are tested from last to first.
    if kind == 'trades':
        # schedule from TRADES repo (different from paper due to bug there)
        if epoch >= 0.75 * args.epochs:
            new_lr = base * 0.1
    elif kind == 'trades_fixed':
        # schedule as in TRADES paper
        if epoch >= args.epochs:
            new_lr = base * 0.001
        elif epoch >= 0.9 * args.epochs:
            new_lr = base * 0.01
        elif epoch >= 0.75 * args.epochs:
            new_lr = base * 0.1
    elif kind == 'cosine':
        # cosine schedule
        phase = (epoch - 1) / args.epochs * np.pi
        new_lr = base * 0.5 * (1 + np.cos(phase))
    elif kind == 'wrn':
        # schedule as in WRN paper
        if epoch >= 0.8 * args.epochs:
            new_lr = base * 0.2 * 0.2 * 0.2
        elif epoch >= 0.6 * args.epochs:
            new_lr = base * 0.2 * 0.2
        elif epoch >= 0.3 * args.epochs:
            new_lr = base * 0.2
    else:
        raise ValueError('Unkown LR schedule %s' % kind)

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build the full per-iteration schedule as one NumPy array.

    The array has ``epochs * niter_per_ep`` entries: an optional linear
    warm-up from ``start_warmup_value`` to ``base_value`` over
    ``warmup_epochs * niter_per_ep`` iterations, followed by a half-cosine
    from ``base_value`` towards ``final_value`` over the remaining
    iterations.

    Returns:
        1-D ``numpy.ndarray`` of per-iteration values.

    Raises:
        AssertionError: If the assembled array does not have exactly
            ``epochs * niter_per_ep`` entries.
    """
    # The whole schedule is generated up front and returned as an array.
    total_iters = epochs * niter_per_ep
    n_warmup = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        ramp = np.linspace(start_warmup_value, base_value, n_warmup)
    else:
        ramp = np.array([])

    steps = np.arange(total_iters - n_warmup)
    cos_factor = 1 + np.cos(np.pi * steps / len(steps))
    decay = final_value + 0.5 * (base_value - final_value) * cos_factor

    full = np.concatenate((ramp, decay))
    assert len(full) == epochs * niter_per_ep
    return full


